#
# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
#
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import numpy as np
import pickle
from enum import Enum
from typing import Optional
import torch
from torch import nn

from detectron2.config import CfgNode
from detectron2.utils.file_io import PathManager

from .vertex_direct_embedder import VertexDirectEmbedder
from .vertex_feature_embedder import VertexFeatureEmbedder


class EmbedderType(Enum):
    """
    Supported ways of mapping mesh vertices into the embedding space.

    The member values are the strings this enum is constructed from
    (e.g. `EmbedderType("vertex_direct")`):
     - VERTEX_DIRECT ("vertex_direct"): direct vertex embedding
     - VERTEX_FEATURE ("vertex_feature"): embedding vertex features
    """

    VERTEX_DIRECT = "vertex_direct"
    VERTEX_FEATURE = "vertex_feature"


def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build a single vertex embedder from its configuration node.

    The embedder class is selected by `embedder_spec.TYPE`. After construction,
    the embedder is initialized from `embedder_spec.INIT_FILE` if that setting
    is a non-empty string, and gradients are disabled for it if
    `embedder_spec.IS_TRAINABLE` is false.

    Args:
        embedder_spec (CfgNode): embedder configuration
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises:
        ValueError: in case of unexpected embedder type
    """
    embedder_type = EmbedderType(embedder_spec.TYPE)

    # Step 1: construct the embedder that matches the requested type.
    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Step 2: optional initialization from file; an empty string means "none".
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        embedder.load(init_file)

    # Step 3: freeze the embedder when it is configured as non-trainable.
    if not embedder_spec.IS_TRAINABLE:
        # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
        embedder.requires_grad_(False)

    return embedder


class Embedder(nn.Module):
    """
    Container module that holds one vertex embedder per mesh. Every embedder is
    registered as a submodule, so its state is saved / loaded together with the
    state dict of this module.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Create an embedder for every mesh listed in the configuration. The
        embedder for mesh `i` is registered as the submodule "embedder_{i}".
        If `cfg.MODEL.WEIGHTS` is a non-empty string, embedder weights are
        afterwards loaded from that checkpoint.

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        self.mesh_names = set()
        logger = logging.getLogger(__name__)
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        embedder_dim = cse_cfg.EMBED_SIZE
        for mesh_name, embedder_spec in cse_cfg.EMBEDDERS.items():
            logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}")
            mesh_embedder = create_embedder(embedder_spec, embedder_dim)
            # pyre-fixme[29]: `Union[nn.Module, torch.Tensor]` is not a function.
            self.add_module(f"embedder_{mesh_name}", mesh_embedder)
            self.mesh_names.add(mesh_name)
        weights_path = cfg.MODEL.WEIGHTS
        if weights_path != "":
            self.load_from_model_checkpoint(weights_path)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a model checkpoint file.

        Files whose path ends with ".pkl" are read with `pickle`, all other
        files with `torch.load` (mapped to CPU). From the "model" entry of the
        checkpoint, the items whose keys start with `prefix` are selected, the
        prefix is stripped from their keys, numpy arrays are converted to
        tensors, and the result is loaded non-strictly into this module.
        A checkpoint without a "model" entry is silently ignored.

        Args:
            fpath (str): path to the checkpoint file
            prefix (Optional[str]): key prefix identifying embedder weights in
                the checkpoint; `DEFAULT_MODEL_CHECKPOINT_PREFIX` is used if None
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        is_pickle = fpath.endswith(".pkl")
        # NOTE: unpickling can execute arbitrary code, so `fpath` should only
        # ever point to a checkpoint from a trusted source.
        with PathManager.open(fpath, "rb") as checkpoint_file:
            if is_pickle:
                checkpoint = pickle.load(checkpoint_file, encoding="latin1")  # pyre-ignore[6]
            else:
                checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
        if checkpoint is None or "model" not in checkpoint:
            return
        model_weights = checkpoint["model"]
        embedder_weights = {}
        for name in model_weights:
            if not name.startswith(prefix):
                continue
            value = model_weights[name]
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            # store under the key relative to this module (prefix removed)
            embedder_weights[name[len(prefix) :]] = value
        # non-strict loading to finetune on different meshes
        # pyre-fixme[6]: Expected `OrderedDict[typing.Any, typing.Any]` for 1st
        #  param but got `Dict[typing.Any, typing.Any]`.
        self.load_state_dict(embedder_weights, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Compute vertex embeddings for the given mesh by calling the embedder
        registered for it. Vertex embeddings are a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        mesh_embedder = getattr(self, f"embedder_{mesh_name}")
        return mesh_embedder()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether this module has an attribute "embedder_{mesh_name}",
        i.e. whether an embedder is available for the given mesh.

        Args:
            mesh_name (str): name of the mesh to look up
        Return:
            True if the corresponding embedder attribute exists, False otherwise
        """
        return hasattr(self, f"embedder_{mesh_name}")
